import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
from statsmodels.stats.proportion import proportion_confint
import numpy as np
from scipy import stats
import sys

sys.path.insert(1,'/REDACTED/fairness/code/utilities/')
import UncertaintySimulation as unc

def computeExpectedCovariance(data, outcomevarb='audited',covarb='black_ind',conditioningvarb='predicted_prob_black_rd',weightvarb=None):
    """Combine per-cell regression slopes into one probability-weighted estimate.

    Rows with a null in ``outcomevarb``, ``covarb`` or ``conditioningvarb`` are
    dropped first.  For every distinct value of ``conditioningvarb`` the
    function fits ``outcomevarb ~ covarb`` by (weighted) least squares with HC1
    standard errors, and records the slope, its standard error, the cell's
    sample variance of ``covarb``, the cell's sample covariance between the two
    variables and the cell's share of the data (row share, or weight share when
    ``weightvarb`` is given).

    Returns
    -------
    tuple
        ``(overall_est, overall_se, tstat, pval, estimates)`` where
        ``overall_est`` is ``sum(prob * coef * variance)`` over cells,
        ``overall_se`` is ``sqrt(sum(prob**2 * se**2 * variance**2))``,
        ``tstat`` is their ratio, ``pval`` is the upper-tail Student-t
        probability with ``len(data) - number_of_cells`` degrees of freedom,
        and ``estimates`` is the per-cell DataFrame.
    """
    weighted = weightvarb is not None
    formula = outcomevarb + ' ~ ' + covarb

    print('updated se version')
    print('startaxpayer_idg len of dataset: ' + str(data.shape[0]))
    data = data.dropna(subset=[outcomevarb, covarb, conditioningvarb])
    print('len of data after dropping nulls: ' + str(data.shape[0]))
    if weighted:
        totalwt = data[weightvarb].sum()
        print('total weight is: ' + str(totalwt))
    conditions = data[conditioningvarb].unique()
    print('conditioning varbs: ' + str(conditions))

    records = []
    for cond in conditions:
        cell = data[data[conditioningvarb] == cond]
        cell_variance = cell[covarb].var()
        cell_covariance = cell[[outcomevarb, covarb]].cov()[outcomevarb][covarb]
        if weighted:
            # Share of total weight that falls in this cell.
            cell_prob = (cell[weightvarb]/totalwt).sum()
            model = smf.wls(formula, data=cell, weights=cell[weightvarb])
        else:
            # Share of rows that fall in this cell.
            cell_prob = len(cell)/len(data)
            model = smf.wls(formula, data=cell)
        fitted = model.fit(cov_type='HC1')
        records.append({
            'condval': cond,
            'prob': cell_prob,
            'covar_outcome_covarb': cell_covariance,
            'coef': fitted.params[covarb],
            'se': fitted.bse[covarb],
            'covarb_conditional_variance': cell_variance,
        })

    estimates = pd.DataFrame(records)
    scale = estimates['prob']*estimates['coef']*estimates['covarb_conditional_variance']
    overall_est = scale.sum()
    # Cells are combined as if independent: variances add, so the terms are squared.
    squared_terms = estimates['prob']**2*estimates['se']**2*estimates['covarb_conditional_variance']**2
    overall_se = np.sqrt(squared_terms.sum())
    tstat = overall_est/overall_se
    dof = len(data)-data[conditioningvarb].nunique()
    pval = stats.t.sf(tstat, dof)
    return overall_est,overall_se,tstat,pval,estimates

def getContaxpayer_idgencyTable(dat,var,varnameT, varnameF, asc=False,blackvar='plurBlack'):
    """Build a 2x2 (plus totals) table of audit rates by race flag and a boolean variable.

    Parameters
    ----------
    dat : DataFrame with ``aud_no_research_audits``, ``var`` and ``blackvar`` columns.
    var : name of the boolean column that defines the table rows.
    varnameT, varnameF : row labels used for ``var == True`` / ``var == False``.
    asc : passed to ``sort_index`` to order the two labelled rows.
    blackvar : name of the boolean race-flag column that defines the table columns.

    Returns
    -------
    DataFrame
        Audit rates in percent with columns ``Black``, ``Non-Black``, ``Total``
        and a final ``Total`` row.  The table is also written as LaTeX.

    NOTE(review): the output path uses a name ``out`` that is not a parameter
    of this function; it must exist at module level - confirm.
    """
    cell_rates = dat.groupby([blackvar,var]).aud_no_research_audits.mean().reset_index()
    # Keyword arguments: positional pivot arguments are rejected by recent pandas,
    # and the column key must follow `blackvar` rather than a fixed column name.
    tabl = cell_rates.pivot(index=var, columns=blackvar, values='aud_no_research_audits')
    tabl.columns = ['Non-Black','Black']
    tabl.index=[varnameF,varnameT]
    tabl = tabl[['Black','Non-Black']].sort_index(ascending=asc)
    # Align the row totals by label so they stay correct whatever order
    # sort_index produced for the two label strings.
    tabl['Total'] = pd.Series({
        varnameT: dat[dat[var]==True].aud_no_research_audits.mean(),
        varnameF: dat[dat[var]==False].aud_no_research_audits.mean(),
    })
    tot = pd.DataFrame({'Black':[dat[dat[blackvar]==True].aud_no_research_audits.mean()],'Non-Black':[dat[dat[blackvar]==False].aud_no_research_audits.mean()],'Total':[dat.aud_no_research_audits.mean()]})
    tot.index=['Total']
    catted = pd.concat([tabl,tot],axis=0)*100
    catted.to_latex(out+'contaxpayer_idgency_BISG_'+var+'_.tex',formatters=(["{0:.2f}%".format,"{0:.2f}%".format,"{0:.2f}%".format]))
    return(catted)

def getWeightedAuditRate(data,wvB='predicted_prob_black',wvNB='predicted_prob_nonblack'):
    """Return the weighted audit rate for each group and their difference.

    Each rate is the sum of ``weight * aud_no_research_audits`` divided by the
    sum of the weights, using column ``wvB`` for the first rate and ``wvNB``
    for the second.  The result is ``(rate_B, rate_NB, rate_B - rate_NB)``.
    """
    audits = data['aud_no_research_audits']
    black_wt = data[wvB]
    nonblack_wt = data[wvNB]
    rate_black = (black_wt*audits).sum()/(black_wt.sum())
    rate_nonblack = (nonblack_wt*audits).sum()/(nonblack_wt.sum())
    return rate_black, rate_nonblack, rate_black-rate_nonblack

def getPropensityWeightedAuditRate(data,weightvar='uswgt'):
    """Return the ``weightvar``-weighted mean of the ``audited`` column."""
    weights = data[weightvar]
    weighted_audits = (weights*data['audited']).sum()
    return weighted_audits/weights.sum()

def getReWeightedGroundTruthDisparity(data,weightvar='uswgt'):
    """Return the weighted audit rate of ``black_ind == 1`` rows minus that of ``black_ind == 0`` rows."""
    black_rows = data[data.black_ind==1]
    nonblack_rows = data[data.black_ind==0]
    rate_black = getPropensityWeightedAuditRate(black_rows,weightvar=weightvar)
    rate_nonblack = getPropensityWeightedAuditRate(nonblack_rows,weightvar=weightvar)
    return rate_black - rate_nonblack


def NCUSGroundTruthTable(data,tablename,out,rounddig=1):
    """Tabulate naive Hajek estimates against reweighted ground truth per subpopulation.

    For each of nine subpopulations (all rows with a taxpayer id, then EIC /
    filing / sex / dependents splits) the table holds, in percent: the naive
    Hajek estimates for both groups and their difference, the ``uswgt``-weighted
    audit rates by ``black_ind`` and their difference, the observed gap between
    the naive difference and that ground truth, and the value returned by
    ``computeDebiasTerm``.  The rounded table is written to
    ``out + tablename`` as ``.tex`` and ``.csv``; the unrounded frame is returned.
    """
    has_id = ~data.taxpayer_id.isna()
    eic = data.isEIC==1
    joint = data.filing_jointly==True
    male = data.isM==True
    deps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', eic),
        ('Non-EIC', ~eic),
        ('Nonjoint EIC', eic & ~joint),
        ('Joint EIC', eic & joint),
        ('Nonjoint Male EIC', eic & ~joint & male),
        ('Nonjoint Nonmale EIC', eic & ~joint & ~male),
        ('Nonjoint Male EIC w/Deps', eic & male & ~joint & deps),
        ('Nonjoint Male EIC no Deps', eic & male & ~joint & ~deps),
    ]

    records = []
    for label, mask in populations:
        subset = data[mask]
        # naiveHajek is handed a copy, so `subset` itself is not modified by these calls.
        hajek_b = 100*naiveHajek(subset.copy(),featureVar='audited',inverseProbWt='uswgt',probVar='predicted_prob_black')
        hajek_nb = 100*naiveHajek(subset.copy(),featureVar='audited',inverseProbWt='uswgt',probVar='predicted_prob_nonblack')
        truth = 100*getReWeightedGroundTruthDisparity(subset,weightvar='uswgt')
        naive_gap = hajek_b-hajek_nb
        rate_b = 100*getPropensityWeightedAuditRate(subset[subset.black_ind==1],'uswgt')
        rate_nb = 100*getPropensityWeightedAuditRate(subset[subset.black_ind==0],'uswgt')
        records.append({
            'Population': label,
            'NCUS Naive Hajek B': hajek_b,
            'NCUS Naive Hajek NB': hajek_nb,
            'NCUS Naive Dif': naive_gap,
            'NCUS GT AR Black': rate_b,
            'NCUS GT AR Nonblack': rate_nb,
            'NCUS GT Disparity': truth,
            'Observed Naive Bias': naive_gap-truth,
            'Asypmtotic Bias': 100*computeDebiasTerm(subset),
        })

    results = pd.DataFrame(records)
    rounded = results.round(rounddig)
    rounded.to_latex(out+tablename+'.tex',index=False)
    rounded.to_csv(out+tablename+'.csv',index=False)
    return results

def getAllMeasurementsTable(data,tablename,out,rounddig=1,have_ground_truth=False):
    """Create the main table of weighted audit rates for every subpopulation and race proxy.

    Rows are the nine subpopulations; columns are grouped (two-level header)
    by proxy - probability weight, plurality, 50% threshold, 90% threshold and,
    when ``have_ground_truth`` is true, the ``black_ind`` / ``nonblack_ind``
    indicators - each with ``Black``, ``Nonblack`` and ``Dif`` in percent.  The
    rounded table is saved to ``out + tablename`` as ``.tex`` and ``.csv``; the
    unrounded frame is returned.
    """
    has_id = ~data.taxpayer_id.isna()
    eic = data.isEIC==1
    joint = data.filing_jointly==True
    male = data.isM==True
    deps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', eic),
        ('Non-EIC', ~eic),
        ('Nonjoint EIC', eic & ~joint),
        ('Joint EIC', eic & joint),
        ('Nonjoint Male EIC', eic & ~joint & male),
        ('Nonjoint Nonmale EIC', eic & ~joint & ~male),
        ('Nonjoint Male EIC w/Deps', eic & male & ~joint & deps),
        ('Nonjoint Male EIC no Deps', eic & male & ~joint & ~deps),
    ]
    # (column group header, black weight column, non-black weight column)
    proxies = [
        ('Weight', 'predicted_prob_black', 'predicted_prob_nonblack'),
        ('Plurality', 'plurBlack', 'plurNonblack'),
        ('Threshold 50 pct', 'black50', 'nonblack50'),
        ('Threshold 90 pct', 'black90', 'nonblack90'),
    ]
    if have_ground_truth:
        proxies.append(('Ground Truth', 'black_ind', 'nonblack_ind'))

    rows = []
    for label, mask in populations:
        print(label)
        subset = data[mask]
        row = [label]
        for _, black_col, nonblack_col in proxies:
            ar_b, ar_nb, dif = getWeightedAuditRate(subset,wvB=black_col,wvNB=nonblack_col)
            row.extend([100*ar_b,100*ar_nb,100*dif])
            print((black_col, nonblack_col))
        rows.append(row)

    results = pd.DataFrame.from_records(rows)
    header = [('','Dataset')]
    for group, _, _ in proxies:
        for stat in ('Black','Nonblack','Dif'):
            header.append((group, stat))
    results.columns = pd.MultiIndex.from_tuples(header)
    rounded = results.round(rounddig)
    rounded.to_latex(out+tablename+'.tex',index=False)
    rounded.to_csv(out+tablename+'.csv',index=False)
    return results

def makeEstimatorTable(data, out, save=True, wvB='predicted_prob_black', wvNB='predicted_prob_nonblack', estimator='probabilistic'):
    """Tabulate audit rates, disparities and a standard error for each subpopulation.

    Parameters
    ----------
    data : DataFrame with the subpopulation flags, ``aud_no_research_audits``,
        the weight columns and (for the linear estimator) ``unitwt``.
    out : path prefix for the saved LaTeX table.
    save : write the rounded table to disk when true.
    wvB, wvNB : columns holding the black / non-black weights.
    estimator : ``'probabilistic'`` (weighted audit rates) or ``'linear'``
        (regression of the outcome on ``wvB``; the intercept is the non-black
        rate and intercept plus slope the black rate).

    Returns
    -------
    DataFrame with one row per subpopulation.

    Raises
    ------
    ValueError
        If ``estimator`` is not one of the two supported names.
    """
    if estimator not in ('probabilistic', 'linear'):
        # Previously an unknown name produced an empty table and failed later
        # with a column length-mismatch ValueError; fail early and clearly.
        raise ValueError("estimator must be 'probabilistic' or 'linear', got %r" % (estimator,))

    has_id = ~data.taxpayer_id.isna()
    maskEIC = data.isEIC==1
    maskFil = data.filing_jointly==True
    maskM = data.isM==True
    maskDeps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', maskEIC),
        ('Non-EIC', ~maskEIC),
        ('Nonjoint EIC', maskEIC & ~maskFil),
        ('Joint EIC', maskEIC & maskFil),
        ('Nonjoint Male EIC', maskEIC & ~maskFil & maskM),
        ('Nonjoint Nonmale EIC', maskEIC & ~maskFil & ~maskM),
        ('Nonjoint Male EIC w/Deps', maskEIC & maskM & ~maskFil & maskDeps),
        ('Nonjoint Male EIC no Deps', maskEIC & maskM & ~maskFil & ~maskDeps),
    ]
    varlist = ['Black Audit Rate','Non-Black Audit Rate','Additive Disparity','Relative Disparity', 'Additive Disparity Standard Error']

    rows = []
    for label, mask in populations:
        subset = data[mask]
        if estimator == 'probabilistic':
            ar_b, ar_nb, dif = getWeightedAuditRate(subset, wvB=wvB, wvNB=wvNB)
        else:
            # Regress on the same weight column the standard error is computed
            # for, instead of a fixed column that ignored `wvB`.
            regcoef, regconst = regressionEstimator(subset, xcol=wvB, ycol='aud_no_research_audits', wcol='unitwt')
            ar_b = regcoef + regconst
            ar_nb = regconst
            dif = ar_b - ar_nb
        rel_dif = ar_b/ar_nb if ar_nb>0 else np.nan
        seChen, seReg = unc.getSEs(subset, pbvarb=wvB, outcome='aud_no_research_audits', wvarb=None)
        # NOTE(review): rates and the disparity are scaled by 100 but the
        # standard error is stored as returned by getSEs - confirm units.
        se = seChen if estimator == 'probabilistic' else seReg
        rows.append([label, 100*ar_b, 100*ar_nb, 100*dif, rel_dif, se])

    results = pd.DataFrame.from_records(rows)
    results.columns = ['Population'] + varlist

    if save:
        suffix = 'proba' if estimator == 'probabilistic' else 'lin'
        results.round(2).to_latex(out+'weighted_est_fullpop_'+suffix+'.tex')
    return results
    
def naive_hajek_all(data,wgtvar):
    """Return the naive Hajek estimates for the predicted-black and predicted-nonblack weights and their difference."""
    est_black = naiveHajek(data,inverseProbWt=wgtvar)
    est_nonblack = naiveHajek(data,inverseProbWt = wgtvar,probVar='predicted_prob_nonblack')
    return est_black, est_nonblack, est_black-est_nonblack

def naive_hajek_all_gt(data,wgtvar):
    """Return the naive Hajek estimates using the ``black_ind`` / ``nonblack_ind`` indicators as weights, and their difference."""
    est_black = naiveHajek(data,inverseProbWt=wgtvar,probVar='black_ind')
    est_nonblack = naiveHajek(data,inverseProbWt = wgtvar,probVar='nonblack_ind')
    return est_black, est_nonblack, est_black-est_nonblack

def makeNCTable(data,out,rounddig=2):
    """Tabulate raw and reweighted ground-truth audit rates for each subpopulation.

    Every row holds the unweighted audit rates by indicator (from
    ``groundTruth``) and the ``uswgt``-reweighted rates (from
    ``naive_hajek_all_gt``), each with additive (percent) and relative
    disparity.  Three LaTeX files are written under ``out``: the full table,
    the raw columns only and the reweighted columns only.  The unrounded
    frame is returned.
    """
    has_id = ~data.taxpayer_id.isna()
    eic = data.isEIC==1
    joint = data.filing_jointly==True
    male = data.isM==True
    deps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', eic),
        ('Non-EIC', ~eic),
        ('Nonjoint EIC', eic & ~joint),
        ('Joint EIC', eic & joint),
        ('Nonjoint Male EIC', eic & ~joint & male),
        ('Nonjoint Nonmale EIC', eic & ~joint & ~male),
        ('Nonjoint Male EIC w/Deps', eic & male & ~joint & deps),
        ('Nonjoint Male EIC no Deps', eic & male & ~joint & ~deps),
    ]
    raw_cols = ['Black Audit Rate','Non-Black Audit Rate','Additive Disparity','Relative Disparity']
    rwt_cols = ['Reweighted Black Audit Rate','Reweighted Non-Black Audit Rate','Reweighted Additive Disparity','Reweighted Relative Disparity']

    rows = []
    for label, mask in populations:
        subset = data[mask]
        ar_b,ar_nb,dif = groundTruth(subset,indvb='black_ind',indvnb='nonblack_ind')
        arb_rw,arb_nb_rw, dif_rb = naive_hajek_all_gt(subset.copy(),'uswgt')
        if ar_nb>0:
            rel_dif_gt = ar_b/ar_nb
        else:
            rel_dif_gt = np.nan
        if arb_nb_rw>0:
            rel_dif_rwt = arb_rw/arb_nb_rw
        else:
            rel_dif_rwt = np.nan
        rows.append([label,100*ar_b,100*ar_nb,100*dif,rel_dif_gt,arb_rw*100,arb_nb_rw*100,dif_rb*100,rel_dif_rwt])

    results = pd.DataFrame.from_records(rows)
    results.columns = ['Population'] + raw_cols + rwt_cols
    results.round(rounddig).to_latex(out+'nc_table.tex')
    results[['Population'] + raw_cols].round(rounddig).to_latex(out+'nc_gt_table.tex')
    results[['Population'] + rwt_cols].round(rounddig).to_latex(out+'nc_rwt_table.tex')
    return results

def regressionEstimator(data,xcol='predicted_prob_black',ycol='audited',wcol='unitwt'):
    """Fit a weighted least-squares line of ``ycol`` on ``xcol`` and return ``(slope, intercept)``.

    A constant column is added to the single regressor and ``wcol`` supplies
    the observation weights.
    """
    design = sm.add_constant(data[[xcol]])
    response = data[ycol]
    fitted = sm.WLS(response, design, weights=data[wcol]).fit()
    coefficients = fitted.params
    return coefficients[xcol],coefficients['const']

def regressionEstimatorTables(data,tablename,out,rounddig=2,xcol='predicted_prob_black',ycol='audited',wcol='unitwt'):
    """Tabulate the weighted-regression slope and intercept for each subpopulation.

    ``regressionEstimator`` is run on each of the nine subpopulations; the
    slope and intercept are reported in percent.  The rounded table is written
    to ``out + tablename`` as ``.tex`` and ``.csv`` and the unrounded frame
    (columns ``Population``, ``Coefficient``, ``Intercept``) is returned.
    """
    maskTriv = ~data.taxpayer_id.isna()
    maskEIC = data.isEIC==1
    maskFil = data.filing_jointly==True
    maskM = data.isM==True
    maskDeps = data.pos_deps == 1
    populations = [
        ('All TP', maskTriv),
        ('EIC', maskEIC),
        ('Non-EIC', ~maskEIC),
        ('Nonjoint EIC', maskEIC & ~maskFil),
        ('Joint EIC', maskEIC & maskFil),
        ('Nonjoint Male EIC', maskEIC & ~maskFil & maskM),
        ('Nonjoint Nonmale EIC', maskEIC & ~maskFil & ~maskM),
        ('Nonjoint Male EIC w/Deps', maskEIC & maskM & ~maskFil & maskDeps),
        ('Nonjoint Male EIC no Deps', maskEIC & maskM & ~maskFil & ~maskDeps),
    ]
    rows = []
    for label, mask in populations:
        regcoef,regconst = regressionEstimator(data[mask],xcol,ycol,wcol)
        rows.append((label, regcoef*100, regconst*100))
    # The column names are set here directly; the earlier relabelling step
    # referenced an undefined `varlist` and could never have run.
    results = pd.DataFrame(rows, columns=['Population','Coefficient','Intercept'])
    results.round(rounddig).to_latex(out+tablename+'.tex',index=False)
    results.round(rounddig).to_csv(out+tablename+'.csv',index=False)
    return(results)

def simpleChen(data,probwt,estimand):
    """Return the mean of column ``estimand`` weighted by column ``probwt``."""
    weights = data[probwt]
    weighted_total = (data[estimand]*weights).sum()
    return weighted_total/weights.sum()

def simpleChenDisparity(data,probB='predicted_prob_black',probNB='predicted_prob_nonblack',estimand='audited'):
    """Return the ``probB``-weighted mean of ``estimand`` minus the ``probNB``-weighted mean."""
    mean_black = simpleChen(data,probB,estimand)
    mean_nonblack = simpleChen(data,probNB,estimand)
    return mean_black-mean_nonblack


def stratifiedChen(data,prob_var,stratvar,wgtvar):
    """Average per-stratum ``prob_var``-weighted audit rates, weighting strata by total ``wgtvar``.

    Within each distinct value of ``stratvar`` the ``prob_var``-weighted mean
    of ``audited`` is computed; the strata are then combined with weights equal
    to the sum of ``wgtvar`` in the stratum.
    """
    stratum_weights = []
    stratum_rates = []
    for stratum in data[stratvar].unique():
        rows = data[data[stratvar]==stratum]
        stratum_weights.append(rows[wgtvar].sum())
        stratum_rates.append(simpleChen(rows, prob_var,'audited'))
    numerator = sum([w*r for w, r in zip(stratum_weights, stratum_rates)])
    return numerator/sum(stratum_weights)

def stratifiedChenDisparity(data,stratvar,wgtvar):
    """Return the stratified estimate for ``predicted_prob_black`` minus that for ``predicted_prob_nonblack``."""
    est_black = stratifiedChen(data,'predicted_prob_black',stratvar,wgtvar)
    est_nonblack = stratifiedChen(data,'predicted_prob_nonblack',stratvar,wgtvar)
    return est_black-est_nonblack


def stratifiedHajekChen(data,prob_var,stratvar,wgtvar):
    """Alternative computation of the stratified estimate, kept as a cross-check.

    Each row is assigned the estimate of its own stratum (``stratifiedChen``
    applied to that stratum alone) and the ``wgtvar``-weighted mean of those
    per-row values is returned.  The input frame is left unmodified: the
    per-row estimates live in a local Series rather than in a new ``est``
    column on the caller's data.
    """
    strata=data[stratvar].unique()
    stratares = {s: stratifiedChen(data[data[stratvar]==s],prob_var,stratvar,wgtvar) for s in strata}
    row_estimates = data[stratvar].map(stratares)
    weights = data[wgtvar]
    return (row_estimates*weights).sum()/weights.sum()

def multlyStratifiedChen(data,prob_var,stratvar,wgtvar):
    """Average per-stratum ``prob_var``-weighted audit rates, weighting strata by ``sum(wgtvar * prob_var)``."""
    stratum_weights = []
    stratum_rates = []
    for stratum in data[stratvar].unique():
        rows = data[data[stratvar]==stratum]
        stratum_weights.append((rows[wgtvar]*rows[prob_var]).sum())
        stratum_rates.append(simpleChen(rows, prob_var,'audited'))
    numerator = sum([w*r for w, r in zip(stratum_weights, stratum_rates)])
    return numerator/sum(stratum_weights)

def multlyStratifiedChenDisparity(data,stratvar,wgtvar):
    """Return ``multlyStratifiedChen`` for ``predicted_prob_black`` minus that for ``predicted_prob_nonblack``."""
    est_black = multlyStratifiedChen(data,'predicted_prob_black',stratvar,wgtvar)
    est_nonblack = multlyStratifiedChen(data,'predicted_prob_nonblack',stratvar,wgtvar)
    return est_black-est_nonblack


def computeBias(data,ind_var='black_ind',p_var='p_black_rd'):
    """Return minus the row-share-weighted within-cell covariance of ``audited`` and ``ind_var``, divided by the mean of ``ind_var``.

    Cells are the distinct values of ``p_var``; each cell's sample covariance
    is weighted by the fraction of rows it contains.
    """
    n_rows = len(data)
    weighted_terms = []
    for p in data[p_var].unique():
        cell = data[data[p_var]==p]
        cell_cov = cell[['audited',ind_var]].cov()['audited'][ind_var]
        weighted_terms.append(cell_cov*(len(cell)/n_rows))
    prevalence = data[ind_var].sum()/len(data)
    return -1*sum(weighted_terms)/prevalence
def computeBiasGeneral(data,voi,ind_var='black_ind',p_var='p_black_rd',wtvar='uswgt'):
    """Same quantity as ``computeBias`` but for an arbitrary outcome column ``voi``.

    ``wtvar`` is accepted but not used by this unweighted version.
    """
    n_rows = len(data)
    weighted_terms = []
    for p in data[p_var].unique():
        cell = data[data[p_var]==p]
        cell_cov = cell[[voi,ind_var]].cov()[voi][ind_var]
        weighted_terms.append(cell_cov*(len(cell)/n_rows))
    prevalence = data[ind_var].sum()/len(data)
    return -1*sum(weighted_terms)/prevalence

def computeDebiasTerm(data,prob_b='p_black_rd',prob_nb='p_nonblack_rd'):
    """Return ``computeBias`` for ``black_ind`` (cells from ``prob_b``) minus ``computeBias`` for ``nonblack_ind`` (cells from ``prob_nb``)."""
    bias_black = computeBias(data,'black_ind',prob_b)
    bias_nonblack = computeBias(data,'nonblack_ind',prob_nb)
    return bias_black-bias_nonblack


def computeBiasWeighted(data,ind_var='black_ind',p_var='p_black_rd',wtvar='uswgt'):
    """Weighted counterpart of ``computeBias``.

    Each row is given the sample covariance of ``audited`` and ``ind_var``
    within its ``p_var`` cell; the result is minus the ``wtvar``-weighted mean
    of those covariances divided by the ``wtvar``-weighted mean of ``ind_var``.
    The per-row covariances are held in a local Series, so the caller's frame
    is not given a ``cov_est`` column (and slices no longer trigger a
    chained-assignment warning).
    """
    cov_dict= {p:data[data[p_var]==p][['audited',ind_var]].cov()['audited'][ind_var] for p in data[p_var].unique()}
    cov_est = data[p_var].map(cov_dict)
    weights = data[wtvar]
    weighted_cov = (cov_est*weights).sum()/(weights.sum())
    return -1*weighted_cov/hajek(data,ind_var,wtvar)
def computeBiasGeneralWeighted(data,voi,ind_var='black_ind',p_var='p_black_rd',wtvar='uswgt'):
    """Weighted counterpart of ``computeBiasGeneral`` for outcome column ``voi``.

    Each row is given the sample covariance of ``voi`` and ``ind_var`` within
    its ``p_var`` cell; the result is minus the ``wtvar``-weighted mean of
    those covariances divided by the ``wtvar``-weighted mean of ``ind_var``.
    The per-row covariances are held in a local Series, so no
    ``cov_est_<voi>`` column is added to the caller's frame.
    """
    cov_dict = {p:data[data[p_var]==p][[voi,ind_var]].cov()[voi][ind_var] for p in data[p_var].unique()}
    cov_est = data[p_var].map(cov_dict)
    weights = data[wtvar]
    weighted_cov = (cov_est*weights).sum()/(weights.sum())
    return -1*weighted_cov/hajek(data,ind_var,wtvar)

def getBiasEstimateTables(data,tablename,out,rounddig=4,weighted=False,wgtvar='uswgt'):
    """Tabulate the bias terms for both groups, their difference and the row count per subpopulation.

    ``computeBias`` is used when ``weighted is False``; otherwise
    ``computeBiasWeighted`` with ``wgtvar``.  The three bias columns are
    reported in percent.  The rounded table is written to ``out + tablename``
    as ``.tex`` and ``.csv``; the unrounded frame is returned.
    """
    has_id = ~data.taxpayer_id.isna()
    eic = data.isEIC==1
    joint = data.filing_jointly==True
    male = data.isM==True
    deps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', eic),
        ('Non-EIC', ~eic),
        ('Nonjoint EIC', eic & ~joint),
        ('Joint EIC', eic & joint),
        ('Nonjoint Male EIC', eic & ~joint & male),
        ('Nonjoint Nonmale EIC', eic & ~joint & ~male),
        ('Nonjoint Male EIC w/Deps', eic & male & ~joint & deps),
        ('Nonjoint Male EIC no Deps', eic & male & ~joint & ~deps),
    ]
    bias_columns = ['Asymptotic Bias (Black)','Asymptotic Bias (Nonblack)','Asymptotic Bias (Dif)']

    records = []
    for label, mask in populations:
        subset = data[mask]
        n_obs = len(subset)
        if weighted is False:
            bias_b = computeBias(subset,'black_ind','p_black_rd')
            bias_nb = computeBias(subset,'nonblack_ind','p_nonblack_rd')
        else:
            bias_b = computeBiasWeighted(subset,'black_ind','p_black_rd',wgtvar)
            bias_nb = computeBiasWeighted(subset,'nonblack_ind','p_nonblack_rd',wgtvar)
        records.append({
            'Series': label,
            bias_columns[0]: bias_b,
            bias_columns[1]: bias_nb,
            bias_columns[2]: bias_b-bias_nb,
            'N Obs': n_obs,
        })

    results = pd.DataFrame(records)
    for column in bias_columns:
        results[column] = results[column]*100
    rounded = results.round(rounddig)
    rounded.to_latex(out+tablename+'.tex',index=False)
    rounded.to_csv(out+tablename+'.csv',index=False)
    return(results)

def getBiasEstimateTablesGeneral(data,voi,tablename,out,rounddig=4,weighted=False,wgtvar='uswgt'):
    """Tabulate the bias terms for outcome ``voi`` per subpopulation.

    ``computeBiasGeneral`` is used when ``weighted is False``; otherwise
    ``computeBiasGeneralWeighted`` with ``wgtvar``.  The three bias columns are
    reported in percent alongside the row count.  The rounded table is written
    to ``out + tablename`` as ``.tex`` and ``.csv``; the unrounded frame is
    returned.
    """
    has_id = ~data.taxpayer_id.isna()
    eic = data.isEIC==1
    joint = data.filing_jointly==True
    male = data.isM==True
    deps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', eic),
        ('Non-EIC', ~eic),
        ('Nonjoint EIC', eic & ~joint),
        ('Joint EIC', eic & joint),
        ('Nonjoint Male EIC', eic & ~joint & male),
        ('Nonjoint Nonmale EIC', eic & ~joint & ~male),
        ('Nonjoint Male EIC w/Deps', eic & male & ~joint & deps),
        ('Nonjoint Male EIC no Deps', eic & male & ~joint & ~deps),
    ]
    bias_columns = ['Asymptotic Bias (Black)','Asymptotic Bias (Nonblack)','Asymptotic Bias (Dif)']

    records = []
    for label, mask in populations:
        subset = data[mask]
        n_obs = len(subset)
        if weighted is False:
            bias_b = computeBiasGeneral(subset,voi, 'black_ind','p_black_rd')
            bias_nb = computeBiasGeneral(subset,voi,'nonblack_ind','p_nonblack_rd')
        else:
            bias_b = computeBiasGeneralWeighted(subset,voi,'black_ind','p_black_rd',wgtvar)
            bias_nb = computeBiasGeneralWeighted(subset,voi,'nonblack_ind','p_nonblack_rd',wgtvar)
        records.append({
            'Series': label,
            bias_columns[0]: bias_b,
            bias_columns[1]: bias_nb,
            bias_columns[2]: bias_b-bias_nb,
            'N Obs': n_obs,
        })

    results = pd.DataFrame(records)
    for column in bias_columns:
        results[column] = results[column]*100
    rounded = results.round(rounddig)
    rounded.to_latex(out+tablename+'.tex',index=False)
    rounded.to_csv(out+tablename+'.csv',index=False)
    return(results)



def getFrechetDisparities(data,priorBlack = None, priorNonblack =None):
    """Return ``(lower, upper)`` bounds on the audit-rate disparity built from ``mu_lower`` / ``mu_upper``.

    Parameters
    ----------
    data : DataFrame with ``audited``, ``p_black_rd``, ``p_nonblack_rd`` and,
        when no priors are given, ``black_ind`` / ``nonblack_ind``.  Helper
        columns (``p_audit_alpha``, ``p_noaudit_alpha``, ``p_outcome_alpha``,
        ``mu_lower_*``, ``mu_upper_*``) are added to this frame in place, so
        pass a copy if the original must stay untouched.
    priorBlack, priorNonblack : group shares used as denominators.  When
        ``priorBlack`` is None both shares are taken from the indicator
        columns instead.

    Returns
    -------
    tuple
        ``(mu_B_L - mu_NB_U, mu_B_U - mu_NB_L)``.
    """
    if priorBlack is None:
        pB = data.black_ind.sum()/len(data)
        pNB = data.nonblack_ind.sum()/len(data)
    else: 
        pB = priorBlack
        pNB = priorNonblack
    # transform('mean') broadcasts each p_black_rd cell's audit rate back onto
    # its rows.  A plain groupby mean is indexed by the cell values, so
    # assigning it as a column aligned on the wrong index.
    data['p_audit_alpha'] = data.groupby('p_black_rd').audited.transform('mean')
    data['p_noaudit_alpha'] = 1- data['p_audit_alpha']
    # Probability of the outcome the row actually had within its cell.
    data['p_outcome_alpha'] = data['p_noaudit_alpha']
    data.loc[data['audited']==True,'p_outcome_alpha'] = data.loc[data['audited']==True,'p_audit_alpha']
    data['mu_lower_B'] = data.apply(lambda x: mu_lower(x['p_black_rd'],x['p_outcome_alpha']),axis=1)
    data['mu_upper_B'] = data.apply(lambda x: mu_upper(x['p_black_rd'],x['p_outcome_alpha']),axis=1)
    # NOTE(review): the non-black bounds are fed 1 - p_nonblack_rd rather than
    # p_nonblack_rd itself - confirm this is the intended group probability.
    data['mu_lower_NB'] = data.apply(lambda x: mu_lower(1-x['p_nonblack_rd'],x['p_outcome_alpha']),axis=1)
    data['mu_upper_NB'] = data.apply(lambda x: mu_upper(1-x['p_nonblack_rd'],x['p_outcome_alpha']),axis=1)
    mu_B_L = (data['mu_lower_B']*data['audited']).mean()/pB
    mu_B_U = (data['mu_upper_B']*data['audited']).mean()/pB
    mu_NB_L = (data['mu_lower_NB']*data['audited']).mean()/pNB
    mu_NB_U = (data['mu_upper_NB']*data['audited']).mean()/pNB
    return (mu_B_L-mu_NB_U,mu_B_U-mu_NB_L)

def mu_lower(prob_race,prob_out_z):
    """Return the lower weight bound ``max(0, 1 + (prob_race - 1) / prob_out_z)``.

    Note the arguments differ slightly from the paper: the predicted outcome
    itself is not needed, only the probability of that outcome, which the
    caller passes as ``prob_out_z``.

    The expression equals ``(prob_race + prob_out_z - 1) / prob_out_z`` and
    never exceeds 1 when ``prob_race <= 1``.  The previous parenthesisation
    reduced to ``prob_race / prob_out_z``, the same ratio ``mu_upper`` uses.
    """
    return max(0, 1 + (prob_race - 1)/prob_out_z)

def mu_upper(prob_race,prob_out_z):
    """Return the upper weight bound: ``prob_race / prob_out_z`` capped at 1."""
    ratio = prob_race/prob_out_z
    return min(1, ratio)

def getFrechetDisparityTable(data,out,tablename,rounddig=2):
    """Tabulate the disparity bounds next to the observed disparity for each subpopulation.

    For every subpopulation the bounds from ``getFrechetDisparities`` (run on
    a copy of the subset), the difference in mean ``audited`` between
    ``black_ind == 1`` and ``black_ind == 0`` rows, and the row count are
    recorded; the first three are in percent.  The rounded table is written to
    ``out + tablename`` as ``.tex`` and ``.csv``; the unrounded frame is
    returned.
    """
    has_id = ~data.taxpayer_id.isna()
    eic = data.isEIC==1
    joint = data.filing_jointly==True
    male = data.isM==True
    deps = data.pos_deps == 1
    populations = [
        ('All TP', has_id),
        ('EIC', eic),
        ('Non-EIC', ~eic),
        ('Nonjoint EIC', eic & ~joint),
        ('Joint EIC', eic & joint),
        ('Nonjoint Male EIC', eic & ~joint & male),
        ('Nonjoint Nonmale EIC', eic & ~joint & ~male),
        ('Nonjoint Male EIC w/Deps', eic & male & ~joint & deps),
        ('Nonjoint Male EIC no Deps', eic & male & ~joint & ~deps),
    ]
    percent_columns = ['Lower Bound','Upper Bound','Ground Truth']

    records = []
    for label, mask in populations:
        subset = data[mask]
        lower,upper = getFrechetDisparities(subset.copy())
        observed = subset[subset.black_ind==1].audited.mean()-subset[subset.black_ind==0].audited.mean()
        records.append({
            'Series': label,
            'Lower Bound': lower,
            'Upper Bound': upper,
            'Ground Truth': observed,
            'N Obs': len(subset),
        })

    results = pd.DataFrame(records)
    for column in percent_columns:
        results[column] = results[column]*100
    rounded = results.round(rounddig)
    rounded.to_latex(out+tablename+'.tex',index=False)
    rounded.to_csv(out+tablename+'.csv',index=False)
    return results


def hajek(data,featureVar='audited',weightVar='base_wgt'):
    """Return the mean of ``featureVar`` weighted by ``weightVar``."""
    weights = data[weightVar]
    weighted_total = (data[featureVar]*weights).sum()
    return weighted_total/(weights.sum())

def naiveHajek(data,featureVar='audited',inverseProbWt='uswgt',probVar='predicted_prob_black'):
    """Return the mean of ``featureVar`` weighted by ``inverseProbWt * probVar``.

    The combined weight is kept in a local Series, so the caller's frame is
    not given an ``ipw_prob`` column and does not need to be copied first.
    """
    combined_weight = data[inverseProbWt]*data[probVar]
    return (combined_weight*data[featureVar]).sum()/(combined_weight.sum())

def naiveHajekDisparity(data,featureVar='audited',inverseProbWt='uswgt'):
    """Return the naive Hajek estimate for ``predicted_prob_black`` minus that for ``predicted_prob_nonblack``."""
    est_black = naiveHajek(data,featureVar=featureVar,inverseProbWt=inverseProbWt,probVar='predicted_prob_black')
    est_nonblack = naiveHajek(data,featureVar=featureVar,inverseProbWt=inverseProbWt, probVar='predicted_prob_nonblack')
    return est_black-est_nonblack

def directCorrection(data):
    """Placeholder with no implementation: ignores ``data`` and returns None."""
    pass

def groundTruth(data,indvb,indvnb):
    """Return the mean of ``audited`` where ``indvb == 1``, where ``indvnb == 1``, and the difference of the two."""
    rate_b = data[data[indvb]==1].audited.mean()
    rate_nb = data[data[indvnb]==1].audited.mean()
    return rate_b, rate_nb, rate_b-rate_nb
def getConvergence(data,n_sample_sizes=10,n_samples=20,bucketsize=10,max_frac=0.1):
    """Evaluate ``computeDebiasTerm`` on repeated bootstrap subsamples of increasing size.

    ``n_sample_sizes`` sizes are spaced evenly up to ``max_frac * len(data)``;
    for each size ``n_samples`` subsamples are drawn with replacement.
    ``bucketsize`` picks the probability columns used as cells (``_10``,
    ``_05`` or the unsuffixed pair).  Returns a DataFrame with one row per
    draw: ``SampleSize``, ``SampleFrac``, ``Run`` and ``Debias``.
    """
    if bucketsize==10:
        suffix = '_10'
    elif bucketsize==5:
        suffix = '_05'
    else:
        suffix = ''
    p_b_rd = 'prob_black_rd' + suffix
    p_nb_rd = 'prob_nonblack_rd' + suffix

    n_rows = len(data)
    sizes = [int(step/n_sample_sizes*max_frac*n_rows) for step in range(1, n_sample_sizes+1)]

    draws = []
    for size in sizes:
        for run in range(n_samples):
            print(size,run)
            subsample = data.sample(size,replace=True)
            draws.append({
                'SampleSize': size,
                'SampleFrac': size/n_rows,
                'Run': run,
                'Debias': computeDebiasTerm(subsample,p_b_rd,p_nb_rd),
            })
    return pd.DataFrame(draws, columns=['SampleSize','SampleFrac','Run','Debias'])



def convergenceForAllDatasets(data, n_sample_sizes=10,n_samples=10,bucketsize=10,max_frac=0.1):
    """Run ``getConvergence`` separately on each of the nine subpopulations.

    Returns a dict mapping the subpopulation label to the DataFrame produced
    by ``getConvergence`` for that subset.
    """
    maskTriv = ~data.taxpayer_id.isna()
    maskEIC = data.isEIC==1
    maskFil = data.filing_jointly==True
    maskM = data.isM==True
    maskDeps = data.pos_deps == 1
    populations = [
        ('All TP', maskTriv),
        ('EIC', maskEIC),
        ('Non-EIC', ~maskEIC),
        ('Nonjoint EIC', maskEIC & ~maskFil),
        ('Joint EIC', maskEIC & maskFil),
        ('Nonjoint Male EIC', maskEIC & ~maskFil & maskM),
        ('Nonjoint Nonmale EIC', maskEIC & ~maskFil & ~maskM),
        ('Nonjoint Male EIC w/Deps', maskEIC & maskM & ~maskFil & maskDeps),
        ('Nonjoint Male EIC no Deps', maskEIC & maskM & ~maskFil & ~maskDeps),
    ]
    results = {}
    for label, mask in populations:
        # Pass the filtered subset; passing the full frame would make every
        # entry a convergence run on the whole data set.
        results[label] = getConvergence(data[mask],n_sample_sizes,n_samples,bucketsize,max_frac)
    return results

def _plotDebiasConvergence(results, xcol, xlabel, path):
    """Plot the mean ``Debias`` per ``xcol`` value for every sample and save the figure to ``path``."""
    fig, ax = plt.subplots(figsize=(10,8))
    for name, frame in results.items():
        means = frame.groupby(xcol).Debias.mean().reset_index()
        ax.plot(means[xcol],means['Debias'],label=name)
    # One legend only: its entries are the sample names.
    ax.legend(title='Sample')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Avg Debiasing Term')
    fig.savefig(path)
    # Close only the figure created here so other open figures are untouched.
    plt.close(fig)


def makeConvergencePlot(results,out):
    """Save two convergence plots of the average debiasing term.

    ``results`` maps a sample name to a DataFrame with ``SampleFrac``,
    ``SampleSize`` and ``Debias`` columns.  One plot uses the sample fraction
    on the x-axis (``debias_samplefrac_convg.png``), the other the sample size
    (``debias_samplesize_convg.png``); both are written under ``out``.  Each
    plot is drawn on its own 10x8 figure, which is closed after saving.
    """
    _plotDebiasConvergence(results, 'SampleFrac', 'Sample Fraction', out+'debias_samplefrac_convg.png')
    _plotDebiasConvergence(results, 'SampleSize', 'Sample Size', out+'debias_samplesize_convg.png')




def computeTomsStats(ncdata,fulldata,out,wgtvar='uswgt_sparseA',stratvar='stratum_sparseA'):
    """Write a summary table comparing disparity estimators on the NC sample and the full data.

    For each of nine subpopulations the function computes group shares, audit
    rates, several disparity estimates (full-data probabilistic, naive Hajek,
    two stratified variants, reweighted ground truth), differences between
    them, and a bias-corrected estimate.  Statistics are rows and
    subpopulations are columns, with a ``Description`` column naming each row.
    The table is saved as ``toms_table.csv`` and ``toms_table.tex`` under
    ``out``.

    Side effects: ``fulldata`` gains ``predicted_prob_nonblack``, ``isM`` and
    ``filing_jointly`` columns in place when they are missing.  Nothing is
    returned.
    """
    if 'eic_ind_x' in ncdata.columns:
        ncdata = ncdata.rename({'eic_ind_x':'eic_ind'},axis=1)
    if 'predicted_prob_nonblack' not in fulldata.columns:
        fulldata['predicted_prob_nonblack'] = 1-fulldata['predicted_prob_black']
    if 'isM' not in fulldata.columns:
        fulldata['isM'] = fulldata.gender_ind=='M'
    if 'filing_jointly' not in fulldata.columns:
        fulldata['filing_jointly'] =fulldata['filing_status']==2
    ncMaskM = ncdata.isM==True
    ncMaskF = ncdata.filing_jointly==True
    ncMaskE = ncdata.isEIC==True
    ncMaskD = ncdata.pos_deps==True
    fuMaskM = fulldata.isM==True
    fuMaskF = fulldata.filing_jointly==True
    fuMaskE = fulldata.isEIC==True
    fuMaskD = fulldata.pos_deps == True

    groups = ['all TP', 'nonEIC','EIC','EIC Joint','EIC Nonjoint', 'EIC Nonjoint Nonmale','EIC Nonjoint Male','EIC Nonjoint Male No Deps','EIC Nonjoint Male Deps']
    ncmasks = [~ncdata.taxpayer_id.isna(),~ncMaskE, ncMaskE, ncMaskE & ncMaskF, ncMaskE & ~ncMaskF, ncMaskE & ~ncMaskF & ~ncMaskM, ncMaskE & ~ncMaskF & ncMaskM,ncMaskE & ~ncMaskF & ncMaskM & ~ncMaskD,ncMaskE & ~ncMaskF & ncMaskM & ncMaskD]
    fullmasks = [~fulldata.taxpayer_id.isna(),~fuMaskE, fuMaskE, fuMaskE & fuMaskF,fuMaskE & ~fuMaskF, fuMaskE & ~fuMaskF & ~fuMaskM, fuMaskE & ~fuMaskF & fuMaskM,fuMaskE & ~fuMaskF & fuMaskM & ~fuMaskD,fuMaskE & ~fuMaskF & fuMaskM & fuMaskD]
    rows = []
    for group, ncmask, fullmask in zip(groups, ncmasks, fullmasks):
        print(group)
        # Work on copies: some helpers called below may add columns to the frame they receive.
        fulldat=fulldata[fullmask].copy()
        ncdat=ncdata[ncmask].copy()
        mean_prb_fullus = fulldat['predicted_prob_black'].mean()
        mean_pr_b_wgt = simpleChen(ncdat,wgtvar,'predicted_prob_black')
        mean_b_wgt = simpleChen(ncdat,wgtvar,'black_ind')
        mean_b_nc = ncdat['black_ind'].mean()
        audit_rate_full = fulldat['audited'].mean()
        audit_rate_nc_wt = simpleChen(ncdat,wgtvar,'audited')
        audit_rate_nc = ncdat['audited'].mean()
        nc_rwt_gt_ar = getPropensityWeightedAuditRate(ncdat,wgtvar)
        disparity_fullchen = simpleChenDisparity(fulldat)
        disparity_naive = naiveHajekDisparity(ncdat,inverseProbWt=wgtvar)
        strat1 = stratifiedChenDisparity(ncdat,wgtvar=wgtvar,stratvar=stratvar)
        strat2 = multlyStratifiedChenDisparity(ncdat,wgtvar=wgtvar,stratvar=stratvar)
        nc_rwt_gt_disp = getReWeightedGroundTruthDisparity(ncdat,wgtvar)
        bias_fullchen = nc_rwt_gt_disp-disparity_fullchen
        bias_naive = nc_rwt_gt_disp-disparity_naive
        bias_strat1 = strat1 - disparity_fullchen
        bias_strat2 = strat2 - disparity_fullchen
        estimatedBiasChen = computeBiasWeighted(ncdat,wtvar=wgtvar)
        biasCorrectedStrat2 = strat2-estimatedBiasChen
        remainingBias = biasCorrectedStrat2 - nc_rwt_gt_disp
        nc_rate_b = ncdat[ncdat.black_ind==1].audited.mean()
        nc_rate_nb = ncdat[ncdat.black_ind==0].audited.mean()
        nc_disp = nc_rate_b-nc_rate_nb
        # -1 marks an undefined ratio (non-black audit rate not positive).
        nc_disp_rel = nc_rate_b/nc_rate_nb if nc_rate_nb>0 else -1
        row = [mean_prb_fullus,mean_pr_b_wgt,mean_b_wgt,audit_rate_full,audit_rate_nc_wt,nc_rwt_gt_ar,disparity_fullchen,disparity_naive,strat1,strat2,nc_rwt_gt_disp,bias_fullchen,bias_naive,bias_strat1,bias_strat2, estimatedBiasChen,biasCorrectedStrat2,remainingBias,mean_b_nc,audit_rate_nc,nc_disp,nc_disp_rel]
        rows.append(row)
    results = pd.DataFrame.from_records(rows).transpose()
    results.columns = groups
    results['Description'] = ['Full US Mean Pred.Black', 'NC->US Mean PredBlack','NC->US Mean True Black','Full US Audit Rate','NC->US Audit Rate','NC->US Audit Rate','Disparity Full Pop chen','Disparity NC->US Naive','Disparity NC->US Strat sampleweight','Disparity NC->US Strat sampleraceweight','NC->US ground truth disparity','Observed Bias - full pop Chen','Observed Bias - Naive','Observed Bias - Strat 1', 'Observed Bias - Strat 2','Estimated Bias (reweighted)', 'Bias Corrected Strat 2','Remaining Bias Post Correction','Raw NC Pr Black','Raw NC Audit Rate', 'Raw NC Add. Disp.','Raw NC Rel. Disp.']
    results.to_csv(out+'toms_table.csv')
    # The LaTeX copy gets its own .tex file; it used to be written to the
    # .csv path and overwrote the CSV saved on the previous line.
    results.to_latex(out+'toms_table.tex')


def chenVarianceBinaryIID(data,weight,estimand):
    """Return a variance figure for the simpleChen estimate of ``estimand``.

    With ``p = simpleChen(data, weight, estimand)`` and ``w = data[weight]``,
    the value returned is ``p * (1 - p) * sum(w**2) / sum(w)**2``.

    NOTE(review): the p*(1-p) form presumably assumes a binary outcome with
    independent observations (as the function name suggests) -- confirm.
    """
    prob = simpleChen(data,weight,estimand)
    wts = data[weight]
    sum_sq_wts = (wts**2).sum()
    total_wt = wts.sum()
    return prob*(1-prob)*sum_sq_wts/total_wt**2

def compareRaceModels(data, bifplus_b,bifplus_nb):
    """Tabulate audit rates by race under two race models and ground truth.

    For each of nine fixed sub-populations (built from the ``isM``,
    ``filing_jointly``, ``isEIC`` and ``pos_deps`` columns) this computes:

    * ``getWeightedAuditRate`` with the ``predicted_prob_black`` /
      ``predicted_prob_nonblack`` columns ("BIFSG"),
    * ``getWeightedAuditRate`` with the ``bifplus_b`` / ``bifplus_nb``
      columns ("BIFSG_plus"),
    * the plain mean of ``audited`` for ``black_ind == 1`` and
      ``black_ind == 0`` rows ("GT"), and their difference.

    Parameters
    ----------
    data : DataFrame holding the columns named above plus ``taxpayer_id``.
    bifplus_b, bifplus_nb : column names passed to ``getWeightedAuditRate``
        for the second race model.

    Returns
    -------
    DataFrame with one row per sub-population and columns ``Population``,
    ``AR_B_*``, ``AR_NB_*`` and ``Disp_*`` for BIFSG, BIFSG_plus and GT.
    Elements 0, 1 and 2 of each ``getWeightedAuditRate`` result are used as
    the black rate, non-black rate and disparity respectively.
    """
    is_male = data.isM==True
    is_joint = data.filing_jointly==True
    is_eic = data.isEIC==True
    has_deps = data.pos_deps==True
    eic_nonjoint = is_eic & ~is_joint
    eic_nonjoint_male = eic_nonjoint & is_male
    subpopulations = [
        ('all TP', ~data.taxpayer_id.isna()),
        ('nonEIC', ~is_eic),
        ('EIC', is_eic),
        ('EIC Joint', is_eic & is_joint),
        ('EIC Nonjoint', eic_nonjoint),
        ('EIC Nonjoint Nonmale', eic_nonjoint & ~is_male),
        ('EIC Nonjoint Male', eic_nonjoint_male),
        ('EIC Nonjoint Male No Deps', eic_nonjoint_male & ~has_deps),
        ('EIC Nonjoint Male Deps', eic_nonjoint_male & has_deps),
    ]
    columns = {name: [] for name in (
        'Population',
        'AR_B_BIFSG', 'AR_NB_BIFSG', 'Disp_BIFSG',
        'AR_B_BIFSG_plus', 'AR_NB_BIFSG_plus', 'Disp_BIFSG_plus',
        'AR_B_GT', 'AR_NB_GT', 'Disp_GT',
    )}
    for label, mask in subpopulations:
        # Slice once and reuse the subset for every estimator.
        dat = data[mask]
        res_bifsg = getWeightedAuditRate(dat,'predicted_prob_black','predicted_prob_nonblack')
        res_bifsgp = getWeightedAuditRate(dat,bifplus_b,bifplus_nb)
        arb_tru = dat[dat.black_ind==1].audited.mean()
        arnb_tru = dat[dat.black_ind==0].audited.mean()
        print(res_bifsgp)
        columns['Population'].append(label)
        columns['AR_B_BIFSG'].append(res_bifsg[0])
        columns['AR_NB_BIFSG'].append(res_bifsg[1])
        columns['Disp_BIFSG'].append(res_bifsg[2])
        columns['AR_B_BIFSG_plus'].append(res_bifsgp[0])
        columns['AR_NB_BIFSG_plus'].append(res_bifsgp[1])
        columns['Disp_BIFSG_plus'].append(res_bifsgp[2])
        columns['AR_B_GT'].append(arb_tru)
        columns['AR_NB_GT'].append(arnb_tru)
        columns['Disp_GT'].append(arb_tru-arnb_tru)
    return pd.DataFrame(columns)

def compareRaceModelsGraph(resdat,out,scale100=True):
    """Scatter-plot audit rates by population for GT, BIFSG and BIFSG+.

    Parameters
    ----------
    resdat : DataFrame with a ``Population`` column and the ``AR_B_*`` /
        ``AR_NB_*`` columns for ``GT``, ``BIFSG`` and ``BIFSG_plus`` (the
        layout returned by ``compareRaceModels``). One row per population
        is expected, since x positions are derived from the unique
        population labels.
    out : path prefix; the figure is written to
        ``out + 'race_model_comparison.png'``.
    scale100 : when True the rates are multiplied by 100 and the y axis is
        labelled as a percentage.
    """
    populations = resdat['Population'].unique()
    # A numpy array (not a ``range``) is needed so the positions can be
    # shifted arithmetically below.
    indices = np.arange(len(populations))
    # np.diff is empty when there is only one population; use unit spacing.
    spacing = np.min(np.diff(indices)) if len(indices) > 1 else 1
    width = spacing/3.
    scalar = 100 if scale100 else 1
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # (x offset, column, colour, legend label, extra scatter kwargs)
    series = [
        (-width/2, 'AR_B_GT', 'r', 'GT B', {}),
        (-width/2, 'AR_NB_GT', 'b', 'GT NB', {}),
        (0, 'AR_B_BIFSG', 'r', 'BIFSG B', {'marker': 'x'}),
        (0, 'AR_NB_BIFSG', 'b', 'BIFSG NB', {'marker': 'x'}),
        (width/2, 'AR_B_BIFSG_plus', 'r', 'BIFSG+ B', {'marker': '^'}),
        (width/2, 'AR_NB_BIFSG_plus', 'b', 'BIFSG+ NB', {'marker': '^'}),
    ]
    for offset, column, colour, label, extra in series:
        ax.scatter(indices+offset, resdat[column]*scalar, color=colour, label=label, s=64, **extra)
    plt.xticks(indices,labels=populations,rotation=60,ha='right')
    plt.xlabel('Population')
    plt.ylabel('Audit Rate (%)' if scale100 else 'Audit Rate')
    plt.legend(title='Estimator')
    plt.savefig(out+'race_model_comparison.png',bbox_inches="tight")
  
def labelbar(ax,rects,nsig=2,fontsize=13,additiveeps=0):
    """Write each bar's height as a text label just above the bar.

    ``rects`` is an iterable of objects exposing ``get_x``, ``get_width`` and
    ``get_height``. Each label is centred horizontally on its bar, placed at
    ``1.005 * height + additiveeps`` and formatted with ``nsig`` decimals.
    The format string is printed once per bar.
    """
    for bar in rects:
        bar_height = bar.get_height()
        fmt = f'%.{nsig}f'
        print(fmt)
        x_centre = bar.get_x() + bar.get_width()/2
        y_label = 1.005*bar_height+additiveeps
        ax.text(x_centre, y_label, fmt % bar_height, ha='center', va='bottom', fontsize=fontsize)

def graphWeightedEstimators(res, out, colb='red', alpb=1, colnb='red', alpnb=1, suffix='', bwidth=0.35, rotation=0, scale100=False):
    """Draw paired Black / Non-Black audit-rate bars per population.

    Parameters
    ----------
    res : DataFrame with columns ``Population``, ``Black Audit Rate``,
        ``Non-Black Audit Rate`` and ``Additive Disparity Standard Error``.
        ``Population`` is assumed to be the first column (needed for the
        single-row padding below).
    out : path prefix; the figure is written to
        ``out + 'Weighted_Estimator_graph' + suffix + '.png'``.
    colb, alpb, colnb, alpnb : colour and alpha of the two bar series.
    suffix : appended to the file name; also selects the height of the
        "(1)/(2)/(3)" panel annotations.
    bwidth : bar width.
    rotation : rotation of the x tick labels.
    scale100 : multiply the rates and their error bars by 100.

    Error bars are 1.96 times the additive-disparity standard error on both
    series. The vertical separators and panel annotations sit at fixed x/y
    positions, so they presumably target one specific multi-panel layout.
    """
    fig,ax=plt.subplots(figsize=(10,8))
    if len(res)==1:
        # Pad a single row with blank rows on both sides so the lone pair of
        # bars is not stretched across the whole axis. The two labels differ
        # so that they remain distinct "populations".
        res = res.copy()
        n_other = len(res.columns)-1
        ghost = pd.DataFrame([['']+[np.nan]*n_other], columns=res.columns)
        ghost2 = pd.DataFrame([['   ']+[np.nan]*n_other], columns=res.columns)
        res = pd.concat([ghost,res,ghost2])
        print(res)
    populations = res.Population.unique()
    idx = np.arange(len(populations))
    width=bwidth
    scale = 100 if scale100 else 1
    # The error bars must be in the same units as the bars they sit on.
    yerr = res['Additive Disparity Standard Error']*1.96*scale
    arb = ax.bar(idx, res['Black Audit Rate']*scale, width, yerr=yerr, color=colb, alpha=alpb, label='Black')
    arnb = ax.bar(idx+width, res['Non-Black Audit Rate']*scale, width, yerr=yerr, color=colnb, alpha=alpnb, label='Non-Black')
    ax.set_xticks(idx + width / 2)
    ax.set_xticklabels(populations,rotation=rotation,ha='right')
    ax.set_xlabel('Population')
    ax.set_ylabel('Audit Rate (%)')
    labelbar(ax,arb)
    labelbar(ax,arnb)
    ax.legend()
    # 'lightgray' is the valid matplotlib colour name ('light gray' is not).
    for x_sep in (1.67, 3.67):
        ax.axvline(x=x_sep, color='lightgray')
    label_y = 5.5 if suffix == '_proba_startwEITC' else 6.5
    for x_text, panel in ((.49, '(1)'), (2.52, '(2)'), (4.52, '(3)')):
        ax.text(x_text, label_y, panel, fontsize=20)
    plt.savefig(out+'Weighted_Estimator_graph'+suffix+'.png',bbox_inches="tight")

def graphWeightedEstimatorsSingleBars(res,out,colb='red',alpb=1,colnb='red',alpnb=1,suffix='',bwidth=0.7,rot=0,numberfontsize=20,labelfontsize=20,nsig=2,axisfontsize=20,aepsl=0,aepsr=0,usexlabel=True,scale100=False):
    """Draw one bar per (race, population): all Black bars, then Non-Black.

    Parameters
    ----------
    res : DataFrame with columns ``Population``, ``Black Audit Rate`` and
        ``Non-Black Audit Rate``. It is not modified.
    out : path prefix; the figure is written to
        ``out + 'Weighted_Estimator_graph_single' + suffix + '.png'``.
    colb, alpb, colnb, alpnb : colour and alpha of the two bar series.
    bwidth : bar width.
    rot, labelfontsize : rotation and font size of the x tick labels.
    numberfontsize, nsig : font size and decimals of the value labels.
    axisfontsize : font size of the x axis label.
    aepsl, aepsr : additive vertical offset for the value labels of the
        Black and Non-Black bars respectively.
    usexlabel : whether to draw the 'Population' x axis label.
    scale100 : multiply the rates by 100 before plotting.

    Tick labels are 'Black ' / 'Non-Black ' followed by the population name
    with the substring 'All' removed.
    """
    fig,ax=plt.subplots(figsize=(10,8))
    n_pop = len(res.Population.unique())
    width=bwidth
    # Work on explicit copies so that relabelling never writes through to
    # (or warns about) the caller's frame.
    res_b = res[['Population','Black Audit Rate']].copy()
    res_b['Population'] = 'Black ' + res_b['Population'].str.replace('All','')
    res_nb = res[['Population','Non-Black Audit Rate']].copy()
    res_nb['Population'] = 'Non-Black ' + res_nb['Population'].str.replace('All','')
    res_b.columns = ['Population','Audit Rate']
    res_nb.columns = ['Population','Audit Rate']
    # First half of the positions holds the Black bars, second half the
    # Non-Black bars.
    idx = np.arange(2*n_pop)
    scale = 100 if scale100 else 1
    arb = ax.bar(idx[:n_pop], res_b['Audit Rate']*scale,width,color=colb,alpha=alpb,align='center')
    arnb = ax.bar(idx[n_pop:],res_nb['Audit Rate']*scale,width,color=colnb,alpha=alpnb,align='center')
    ax.set_xticks(idx)
    ax.set_xticklabels(list(res_b.Population.unique())+list(res_nb.Population.unique()),rotation=rot,fontsize=labelfontsize)
    labelbar(ax,arb,nsig,numberfontsize,additiveeps=aepsl)
    labelbar(ax,arnb,nsig,numberfontsize,additiveeps=aepsr)
    if usexlabel:
        ax.set_xlabel('Population', fontsize=axisfontsize)
    ax.set_ylabel('Audit Rate (%)')
    plt.savefig(out+'Weighted_Estimator_graph_single'+suffix+'.png',bbox_inches="tight")





def generalChen(data,weightvars,estimand):
    """Evaluate simpleChen for every weight column and append a contrast.

    Returns a list with one ``simpleChen(data, w, estimand)`` value per entry
    of ``weightvars`` followed by one extra element: the first estimate minus
    the second when exactly two weight columns are given, otherwise the
    largest pairwise difference among the estimates.
    """
    estimates = [simpleChen(data,wv,estimand) for wv in weightvars]
    if len(weightvars)==2:
        estimates.append(estimates[0]-estimates[1])
        return estimates
    # All ordered pairs (including i == j) are considered, so the maximum is
    # never negative.
    pairwise = [minuend-subtrahend for subtrahend in estimates for minuend in estimates]
    estimates.append(max(pairwise))
    return estimates
def generalChenGroupedHelper(data,weightvars,estimand,groupvar=None):
    """Run generalChen on the whole frame or separately per group.

    Without ``groupvar`` the ``generalChen`` list for ``data`` is returned.
    Otherwise the result is a dict mapping each unique value of
    ``data[groupvar]`` to the ``generalChen`` list for that group's rows.
    """
    if groupvar is None:
        return generalChen(data,weightvars,estimand)
    return {
        group: generalChen(data[data[groupvar]==group],weightvars,estimand)
        for group in data[groupvar].unique()
    }

def getGeneralChenGrouped(data,weightvars,estimand,groupvar=None):
    """Tabulate the first two generalChen results for each group.

    Parameters
    ----------
    data : DataFrame passed through to ``generalChenGroupedHelper``.
    weightvars : weight column names; the value computed for the first one
        is reported as "Black" and the second element of each group's result
        as "Nonblack".
    estimand : name forwarded to the estimator and used as column prefix.
    groupvar : column to group by.

    Returns
    -------
    DataFrame with columns ``groupvar``, ``estimand + '_Black'`` and
    ``estimand + '_Nonblack'`` (one row per group), or ``None`` when
    ``groupvar`` is None.
    """
    results = generalChenGroupedHelper(data,weightvars,estimand,groupvar)
    if groupvar is None:
        # No table is built for the ungrouped case.
        return None
    cats = list(results.keys())
    # The two columns do not depend on which weight variable is being
    # visited, so they are built once rather than once per weight variable.
    return pd.DataFrame({
        groupvar: cats,
        estimand+'_'+'Black': [results[v][0] for v in cats],
        estimand+'_'+'Nonblack': [results[v][1] for v in cats],
    })


def computeWeightedMean(data,var,wtvar):
    """Return the mean of ``data[var]`` weighted by ``data[wtvar]``."""
    weights = data[wtvar]
    weighted_total = (data[var]*weights).sum()
    return weighted_total/weights.sum()
def compute2WeightWeightedMean(data,var,racewt,sampwt):
    """Return the mean of ``data[var]`` weighted by ``sampwt * racewt``.

    Each row's weight is the product of its sample weight and its race
    weight; the result is sum(var * weight) / sum(weight).
    """
    sample_weights = data[sampwt]
    race_weights = data[racewt]
    numerator = (data[var]*sample_weights*race_weights).sum()
    denominator = (race_weights*sample_weights).sum()
    return numerator/denominator

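#### The two race-weighted means are computed from the same observations, so
#### they are not truly independent. This version accounts for that dependence
#### by subtracting a cross term (unless naive=True).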
def computeModifiedWeightedStandardErrorForDifference(data,var,racewt1='predicted_prob_black',racewt2='predicted_prob_nonblack',samplewt='base_wgt',naive=False): 
    """Standard error of the difference between two doubly-weighted means.

    Each row's weight for group k is ``data[racewtk] * data[samplewt]``. With
    ``T_k`` the total weight of group k and ``s2`` the (unweighted) sample
    variance of ``data[var]``, the returned value is

        sqrt(s2 * (sum(w1**2)/T1**2 + sum(w2**2)/T2**2 - 2*cross))

    where ``cross = sum(racewt1*racewt2*samplewt**2) / (T1*T2)``. The cross
    term is dropped when ``naive`` is True. The cross term is printed in
    either case.
    """
    outcome_variance = data[var].var()
    weights_1 = data[racewt1]*data[samplewt]
    weights_2 = data[racewt2]*data[samplewt]
    total_1 = weights_1.sum()
    total_2 = weights_2.sum()
    share_1 = weights_1.pow(2).sum()/(total_1**2)
    share_2 = weights_2.pow(2).sum()/(total_2**2)
    crossterm = (data[racewt1]*data[racewt2]*data[samplewt]*data[samplewt]).sum()/(total_1*total_2)
    print('crossterm: ',crossterm)
    coeff = share_1 + share_2
    if naive:
        return np.sqrt(outcome_variance*coeff)
    return np.sqrt(outcome_variance*(coeff - 2*crossterm))

def computeBIFSGCI(data,var,samplewt='base_wgt',racewt1='predicted_prob_black',racewt2='predicted_prob_nonblack',zstat=1.96,naive=False):
    """Estimate two race-weighted means, their difference and its interval.

    Returns the tuple ``(est1, est2, dif, se, lower, upper)`` where ``est1``
    and ``est2`` come from ``compute2WeightWeightedMean`` with ``racewt1`` and
    ``racewt2``, ``se`` comes from
    ``computeModifiedWeightedStandardErrorForDifference`` and the bounds are
    ``dif -/+ zstat * se``. The difference and the standard error are printed.
    """
    mean_1 = compute2WeightWeightedMean(data,var,racewt1,samplewt)
    mean_2 = compute2WeightWeightedMean(data,var,racewt2,samplewt)
    std_err = computeModifiedWeightedStandardErrorForDifference(data,var=var,racewt1=racewt1,racewt2=racewt2,samplewt=samplewt,naive=naive)
    gap = mean_1-mean_2
    print('estimated dif: ', gap)
    print('se: ',std_err)
    margin = zstat*std_err
    return (mean_1, mean_2, gap, std_err, gap-margin, gap+margin)

def computeBIFSGDifAndCIGroup(data,groupvar='adj_bin',var='apm',samplewt='base_wgt',racewt1='predicted_prob_black',racewt2='predicted_prob_nonblack',zstat=1.96,naiveSE=False):
    """Run computeBIFSGCI within each group and collect the results.

    Parameters
    ----------
    data : DataFrame containing ``groupvar``, ``var`` and the weight columns.
    groupvar : column whose unique values define the groups.
    var : outcome column.
    samplewt, racewt1, racewt2 : weight columns forwarded to
        ``computeBIFSGCI``.
    zstat : critical value for the interval bounds; forwarded to
        ``computeBIFSGCI`` (it was previously accepted but ignored, so the
        bounds always used 1.96).
    naiveSE : forwarded as ``naive``.

    Returns
    -------
    DataFrame sorted by ``Group`` with columns ``Group``, ``Est_B``,
    ``Est_NB``, ``Dif``, ``SE``, ``LB``, ``UB`` and ``TotalWt`` (the sum of
    ``samplewt`` within the group).
    """
    columns = ['Group','Est_B','Est_NB','Dif','SE','LB','UB','TotalWt']
    rows = []
    for g in data[groupvar].unique():
        dat = data[data[groupvar]==g]
        est_b, est_nb, dif, se, lb, ub = computeBIFSGCI(dat,var=var,samplewt=samplewt,racewt1=racewt1,racewt2=racewt2,zstat=zstat,naive=naiveSE)
        rows.append([g, est_b, est_nb, dif, se, lb, ub, dat[samplewt].sum()])
    # Passing ``columns`` keeps the schema intact even when there are no groups.
    return pd.DataFrame(rows, columns=columns).sort_values('Group')




def bootstrapVariance(data,T=100, probvar='predicted_prob_black',auditvarb='audited'):
    """Simulate audit rates under repeated random imputation of race.

    On each of ``T`` draws every row is assigned to the "black" group with
    probability ``data[probvar]`` (independent Bernoulli draws from numpy's
    global random state), and the mean of ``data[auditvarb]`` is computed for
    the imputed black and non-black groups.

    Parameters
    ----------
    data : DataFrame containing ``probvar`` and ``auditvarb``. It is not
        modified; the imputed indicator is kept in a local Series rather
        than written back as an ``imputed_black`` column.
    T : number of draws.
    probvar : column of per-row probabilities.
    auditvarb : outcome column.

    Returns
    -------
    dict with keys ``'ar_b'``, ``'ar_nb'`` and ``'dif'``, each a list of
    length ``T`` holding the per-draw black rate, non-black rate and their
    difference. No variance is computed here; callers summarise the draws.
    """
    results = {'ar_b':[],'ar_nb':[],'dif':[]}
    probs = data[probvar]
    outcome = data[auditvarb]
    n_rows = len(data)
    for _ in range(T):
        # Keep the draw aligned with the frame's index so the arithmetic
        # (including NaN handling in .sum()) matches column-based code.
        imputed = pd.Series(np.random.binomial(n=1,p=probs,size=n_rows), index=data.index)
        not_imputed = 1-imputed
        ar_b = (imputed*outcome).sum()/imputed.sum()
        ar_nb = (not_imputed*outcome).sum()/not_imputed.sum()
        results['ar_b'].append(ar_b)
        results['ar_nb'].append(ar_nb)
        results['dif'].append(ar_b-ar_nb)
    return results
